import sys

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Extend the module search path before the project-local import of
# `dataset.Tofu` further down. "src" is a relative entry, so this presumably
# expects the script to be launched from the repository root -- TODO confirm.
sys.path.append("src")
import argparse
import os
import random

import numpy as np
import torch
from transformers import DataCollatorForLanguageModeling

from dataset.Tofu import ToFU


def args_parser():
    """Parse the command-line options of the training script.

    Options are read from ``sys.argv``. ``--model_name`` is mandatory; every
    other option has a default.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``model_name``,
        ``cache_dir``, ``seed``, ``epochs``, ``lr``, ``num_warmup_steps``,
        ``batch_size``, ``save_dir`` and ``subset``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--model_name", {"type": str, "required": True, "help": "Model name"}),
        ("--cache_dir", {"type": str, "default": ".cache", "help": "Cache directory"}),
        ("--seed", {"type": int, "default": 0, "help": "Random seed"}),
        ("--epochs", {"type": int, "default": 5, "help": "Number of epochs to train"}),
        ("--lr", {"type": float, "default": 1e-5, "help": "Learning rate"}),
        (
            "--num_warmup_steps",
            {"type": int, "default": 0, "help": "Number of warmup steps"},
        ),
        ("--batch_size", {"type": int, "default": 32, "help": "Batch size"}),
        (
            "--save_dir",
            {"type": str, "default": "files/models/tofu_opt", "help": "Save dir"},
        ),
        ("--subset", {"type": str, "default": "retain90", "help": "Subset"}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def set_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then sets the cuDNN ``deterministic`` flag and clears the ``benchmark``
    flag.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def main():
    """Fine-tune a causal language model on a ToFU subset.

    Reads the command-line options, seeds the random number generators,
    builds the training set and its collator through
    ``ToFU.build_pretrain_dataset``, trains with the Hugging Face ``Trainer``
    and saves the final model to ``--save_dir``.

    Warmup: a positive ``--num_warmup_steps`` is used as given. With the
    default of 0 the warmup length falls back to
    ``len(train_dataset) // (batch_size * grad_accum * WORLD_SIZE)``.
    """
    args = args_parser()
    set_seed(args.seed)

    # Shared by the warmup estimate and TrainingArguments so the two cannot
    # drift apart.
    grad_accum_steps = 4

    dataset = ToFU("ToFU", subset=args.subset)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=False
    )
    if tokenizer.pad_token is None:
        # No pad token in this tokenizer: register one. The embedding matrix
        # is resized to the new vocabulary size after the model is loaded.
        tokenizer.add_special_tokens({"pad_token": "[pad]"})
    train_dataset, collector = dataset.build_pretrain_dataset(
        tokenizer, subset=args.subset
    )

    num_devices = int(os.environ.get("WORLD_SIZE", 1))
    if args.num_warmup_steps > 0:
        # Honour the explicit CLI value (it used to be parsed but ignored).
        warmup_steps = args.num_warmup_steps
    else:
        # Fallback: dataset size divided by the effective batch size
        # (per-device batch * accumulation steps * WORLD_SIZE).
        warmup_steps = len(train_dataset) // (
            args.batch_size * grad_accum_steps * num_devices
        )

    # Evaluation is not configured: no eval dataset is passed to the Trainer.
    training_args = TrainingArguments(
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        logging_dir="logs",
        logging_steps=10,
        save_steps=10,
        warmup_steps=warmup_steps,
        save_total_limit=1,
        greater_is_better=False,
        output_dir=args.save_dir,
    )
    # NOTE(review): device_map="auto" is combined with a WORLD_SIZE-aware
    # warmup computation; confirm this is intended for multi-process launches.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16,
        cache_dir=args.cache_dir,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    # Match the embedding table to the tokenizer, which may have gained a
    # pad token above.
    model.resize_token_embeddings(len(tokenizer))

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=collector,
    )
    trainer.train()
    trainer.save_model()


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
